{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Definition of some constants\n",
    "| Item          | Data Type             | Description |\n",
    "| :-----------: | :-------------------: | :---------: |\n",
    "| data_root     | String                | The root directory to save images |\n",
    "| split_radio   | Tuple of 3 elements   | The ratio of the total dataset to use for training, validation and test respectively |\n",
    "| min_num_images_per_class | Integer | Classes with fewer images will be removed from the whole dataset |\n",
    "| min_num_train_images_per_class | Integer | Minimal number of images per class in the training set, to guarantee training performance. **Ensure that `min_num_train_images_per_class` $\\le$ `min_num_images_per_class`** |\n",
    "| random_seed | Integer | The random seed | "
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "outputs": [],
   "source": [
    "data_root = os.path.expanduser('~/vggface3d_sm')\n",
    "assert os.path.exists(data_root), \\\n",
    "    'Dataset directory not found: %s' % data_root\n",
    "\n",
    "split_radio = (0.8, 0.1, 0.1)\n",
    "assert sum(split_radio) == 1\n",
    "min_num_images_per_class = 10\n",
    "min_num_train_images_per_class = 5\n",
    "\n",
    "random_seed = 2333\n",
    "np.random.seed(random_seed)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Generating CSV Files\n",
    "CSVs are saved in the `data_root` directory, i.e., `data_root/train.csv`, `data_root/eval.csv`, `data_root/test.csv`.\n",
    "### Overview of DataRoot\n",
    "```text\n",
    "vggface3d_sm\n",
    "|── train.csv\n",
    "|── eval.csv\n",
    "|── test.csv\n",
    "|── dirty.csv\n",
    "├── n000853\n",
    "│   ├── 0001_03.npy\n",
    "│   ├── 0001_03.png\n",
    "│   ├── 0002_01.npy\n",
    "│   ├── 0002_01.png\n",
    "│   ├── 0003_01.npy\n",
    "│   ├── 0003_01.png\n",
    "│   ├── 0004_01.npy\n",
    "│   ├── ......\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "everything finished!\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "def _get_data_of_one_class(_cls_name, shuffle=True):\n",
    "    _cls_dir = os.path.realpath(os.path.join(data_root, _cls_name))\n",
    "    assert os.path.exists(_cls_dir) and os.path.isdir(_cls_dir)\n",
    "    _data = []\n",
    "    for file in os.listdir(_cls_dir):\n",
    "        file_name, file_ext = os.path.splitext(file)\n",
    "        if file_ext != '.png':\n",
    "            continue\n",
    "        if not os.path.exists(os.path.join(_cls_dir, '%s.npy' % file_name)):\n",
    "            print(\"%s exists but %s can not be found\" % (\n",
    "                os.path.join(_cls_dir, '%s.png' % file_name), \n",
    "                os.path.join(_cls_dir, '%s.npy' % file_name)))\n",
    "            continue\n",
    "        _data.append([os.path.join(_cls_dir, '%s.png' % file_name), \n",
    "                      os.path.join(_cls_dir, '%s.npy' % file_name),\n",
    "                      _cls_name])\n",
    "    if shuffle:\n",
    "        np.random.shuffle(_data)\n",
    "    return _data\n",
    "\n",
    "\n",
    "# train_data, eval_data, test_data, dirty_data = [[]] * 4\n",
    "train_data, eval_data, test_data, dirty_data = [], [], [], []\n",
    "\n",
    "for cls_name in os.listdir(data_root):\n",
    "    # Skip if it is not a folder\n",
    "    if not os.path.isdir(os.path.join(data_root, cls_name)):\n",
    "        continue\n",
    "    cls_data = _get_data_of_one_class(cls_name, shuffle=True)\n",
    "    if len(cls_data) < min_num_images_per_class:\n",
    "        dirty_data.extend(cls_data)\n",
    "        continue\n",
    "    \n",
    "    num_train_images = max(min_num_train_images_per_class, \n",
    "                           int(len(cls_data) * split_radio[0]))\n",
    "    num_eval_images = int((len(cls_data) - num_train_images) * \n",
    "                          split_radio[1]/(1-split_radio[0]))\n",
    "    num_test_images = len(cls_data) - num_train_images - num_eval_images\n",
    "    \n",
    "    train_data.extend(cls_data[:num_train_images])\n",
    "    eval_data.extend(cls_data[num_train_images:num_train_images+num_eval_images])\n",
    "    test_data.extend(cls_data[-num_test_images:])\n",
    "\n",
    "train_data = np.array(train_data)\n",
    "train_df = pd.DataFrame({'rgb_image_path' : train_data[:, 0] if len(train_data)>0 else [], \n",
    "                         'dep_image_path' : train_data[:, 1] if len(train_data)>0 else [], \n",
    "                         'cls_name': train_data[:, 2] if len(train_data)>0 else []})\n",
    "train_df.to_csv(os.path.join(data_root, 'train.csv'), index=False, sep=',')\n",
    "\n",
    "eval_data = np.array(eval_data)\n",
    "eval_df = pd.DataFrame({'rgb_image_path' : eval_data[:, 0] if len(eval_data)>0 else [], \n",
    "                        'dep_image_path' : eval_data[:, 1] if len(eval_data)>0 else [], \n",
    "                        'cls_name': eval_data[:, 2] if len(eval_data)>0 else []})\n",
    "eval_df.to_csv(os.path.join(data_root, 'eval.csv'), index=False, sep=',')\n",
    "\n",
    "test_data = np.array(test_data)\n",
    "test_df = pd.DataFrame({'rgb_image_path' : test_data[:, 0] if len(test_data)>0 else [], \n",
    "                        'dep_image_path' : test_data[:, 1] if len(test_data)>0 else [], \n",
    "                        'cls_name': test_data[:, 2] if len(test_data)>0 else []})\n",
    "test_df.to_csv(os.path.join(data_root, 'test.csv'), index=False, sep=',')\n",
    "\n",
    "dirty_data = np.array(dirty_data)\n",
    "dirty_df = pd.DataFrame({'rgb_image_path' : dirty_data[:, 0] if len(dirty_data)>0 else [], \n",
    "                         'dep_image_path' : dirty_data[:, 1] if len(dirty_data)>0 else [], \n",
    "                         'cls_name': dirty_data[:, 2] if len(dirty_data)>0 else []})\n",
    "dirty_df.to_csv(os.path.join(data_root, 'dirty.csv'), index=False, sep=',')\n",
    "print(\"everything finished!\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "outputs": [
    {
     "data": {
      "text/plain": "                                      rgb_image_path  \\\n0  /Users/xingwxiong/vggface3d_sm/n000910/0111_01...   \n1  /Users/xingwxiong/vggface3d_sm/n000910/0301_01...   \n2  /Users/xingwxiong/vggface3d_sm/n000910/0251_03...   \n3  /Users/xingwxiong/vggface3d_sm/n000910/0031_01...   \n4  /Users/xingwxiong/vggface3d_sm/n000910/0241_01...   \n\n                                      dep_image_path cls_name  \n0  /Users/xingwxiong/vggface3d_sm/n000910/0111_01...  n000910  \n1  /Users/xingwxiong/vggface3d_sm/n000910/0301_01...  n000910  \n2  /Users/xingwxiong/vggface3d_sm/n000910/0251_03...  n000910  \n3  /Users/xingwxiong/vggface3d_sm/n000910/0031_01...  n000910  \n4  /Users/xingwxiong/vggface3d_sm/n000910/0241_01...  n000910  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>rgb_image_path</th>\n      <th>dep_image_path</th>\n      <th>cls_name</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0111_01...</td>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0111_01...</td>\n      <td>n000910</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0301_01...</td>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0301_01...</td>\n      <td>n000910</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0251_03...</td>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0251_03...</td>\n      <td>n000910</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0031_01...</td>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0031_01...</td>\n      <td>n000910</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0241_01...</td>\n      <td>/Users/xingwxiong/vggface3d_sm/n000910/0241_01...</td>\n      <td>n000910</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {},
     "output_type": "execute_result",
     "execution_count": 30
    }
   ],
   "source": [
    "train_df.head()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "outputs": [],
   "source": [
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}